#rm(list = ls())
library(tidyverse)
── Attaching core tidyverse packages ─────────────────────────────────────────────────────────── tidyverse 2.0.0 ──
✔ dplyr 1.1.2 ✔ readr 2.1.4
✔ forcats 1.0.0 ✔ stringr 1.5.0
✔ ggplot2 3.4.3 ✔ tibble 3.2.1
✔ lubridate 1.9.2 ✔ tidyr 1.3.0
✔ purrr 1.0.2 ── Conflicts ───────────────────────────────────────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::filter() masks stats::filter()
✖ dplyr::lag() masks stats::lag()
ℹ Use the conflicted package (http://conflicted.r-lib.org/) to force all conflicts to become errors
# Parallel-execution backend used later by furrr::future_map_dfr().
library(future)
# Extra ggplot2 themes and palettes (theme_few, few_pal).
library(ggthemes)
# Fix the RNG seed so jittered plots and random choices are reproducible.
set.seed(1245264)
In this notebook we will test the performance of varKode to distinguish species of Stigmaphyllon and figure out the best parameters for training a dataset.
To start, we produced images from different numbers of kmers. We can suppose that shorter kmers will offer lower resolution to resolve species, but they will also create smaller files that require less computation. Here we will test whether images based on longer kmers result in higher accuracy. As an example, here are images produced from 200Mb for the same sample, but different kmer sizes (5-9):
knitr::include_graphics(paste0('images_',5:9,'/S_bannisterioides+S-91_00200000K.png'))
We also used different amounts of data to produce images, since we want to figure out the lowest amount needed to distinguish species. With less data, figures get more noisy since chance plays a bigger role in the observed kmer frequencies. This should be more severe for larger kmer sizes, since each kmer will be more unique in the genome.
For example, images for 5-mer for the same sample as above, for 500Kb and 200Mb:
knitr::include_graphics(paste0('images_6/S_bannisterioides+S-91_00',c('000500','200000'),'K.png'))
The same, but for 8-mers:
knitr::include_graphics(paste0('images_8/S_bannisterioides+S-91_00',c('000500','200000'),'K.png'))
Now that we understand the differences between images, let’s understand the effect in accuracy. We previously trained CNN models to recognize images for a combination of kmer sizes and amount of data, with 10 replicates for each combination. In each replicate, we kept 3 randomly chosen samples per species as a validation set and checked the accuracy of the trained model in guessing the species of these samples, for different amounts of data used for the validation sample. What we want is to find:
1 - The lowest kmer size to produce high accuracy
2 - The lowest amount of data needed
3 - Whether the amount of data used for training and for querying must be similar.
The results of these simulations were saved as a csv table, let’s load it (ignoring the first, index column):
df = read_csv('kmerSize_VS_bp.csv')[-1]
New names:Rows: 4500 Columns: 11── Column specification ───────────────────────────────────────────────────────────────────────────────────────────
Delimiter: ","
chr (3): bp_training, samples_training, samples_valid
dbl (8): ...1, kmer_size, replicate, bp_valid, n_samp_training, n_samp_valid, valid_loss, valid_acc
ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
df
Now let’s make sure bp_training and bp_valid are treated as ordered factors for nice plotting:
# Build the ordered factor levels for training-data amount, converted to Mb,
# with the pooled 'all' category appended last.  Entries of bp_training that
# contain '|' are models trained on a mix of data amounts; they are excluded
# here and become the single level 'all'.
not_all = as.character(sort(as.integer(unique(df$bp_training[!str_detect(df$bp_training,'\\|',)]))/1e6))
ordered_levels = c(not_all,'all')
df = df %>%
# bp counts -> Mb as character; the '|'-containing entries coerce to NA here,
# which is what triggers the "NAs introduced by coercion" warning below --
# that warning is expected and handled on the next line.
mutate(bp_training = as.character(as.integer(bp_training)/1e6) ) %>%
# Recode the NAs produced above as the 'all' category.
mutate(bp_training = replace_na(bp_training, 'all')) %>%
mutate(bp_training = factor(bp_training,
levels = ordered_levels,
ordered = TRUE),
# bp_valid is numeric, so it never produces an 'all' level; sharing
# ordered_levels just keeps the axes comparable across plots.
bp_valid = factor(as.character(as.integer(bp_valid)/1e6),
levels=ordered_levels,
ordered = TRUE),
# Ordered factor so kmer sizes facet and sort numerically, not lexically.
kmer_size = factor(as.character(kmer_size),
levels = as.character(sort(unique(kmer_size))),
ordered = TRUE
)
)
Warning: There was 1 warning in `mutate()`.
ℹ In argument: `bp_training = as.character(as.integer(bp_training)/1e+06)`.
Caused by warning:
! NAs introduced by coercion
df
Let’s summarize these results in table so we can put some numbers in the paper:
# Min / mean / max validation accuracy per (kmer size, training amount) for
# the paper.  .groups = 'drop' returns an ungrouped tibble and silences the
# "has grouped output" message.
df %>%
group_by(kmer_size,bp_training) %>%
summarize(min_valid = min(valid_acc),
mean_valid = mean(valid_acc),
max_valid = max(valid_acc),
.groups = 'drop')
`summarise()` has grouped output by 'kmer_size'. You can override using the `.groups` argument.
Now we can plot:
# Facet labeller that prefixes each kmer-size value with a readable caption.
kmer_labeller = as_labeller(function(value) paste0('kmer length:',value))
# Raw per-replicate accuracy: training amount vs validation amount, colored
# by validation accuracy, one panel per kmer size.
ggplot(df) +
geom_jitter(aes(x = bp_training, y = bp_valid, color = valid_acc)) +
scale_color_viridis_c('Validation\naccuracy', option = 'inferno', limits = c(0,1)) +
facet_grid(~kmer_size, labeller = kmer_labeller) +
coord_equal() +
xlab('Data in training images (Mb)') +
ylab('Data in validation images (Mb)')
NA
Now a version with averaged accuracy
# Heatmap of accuracy averaged over the 10 replicates for each cell of
# (kmer size, training amount, validation amount).
p = df %>%
group_by(kmer_size,bp_training,bp_valid) %>%
summarize(valid_acc = mean(valid_acc)) %>%
ggplot(aes(x = bp_training, y = bp_valid, fill = valid_acc)) +
geom_raster() +
#geom_text(aes(label=sprintf(100*valid_acc,fmt='%2.0f')),size=4.5*5/14) +
scale_fill_viridis_c('Average\nvalidation\naccuracy', option = 'magma', limits = c(0,1),labels=scales::percent) +
facet_grid(~kmer_size, labeller = kmer_labeller) +
coord_equal() +
xlab('Data in training images (Mb)') +
ylab('Data in validation images (Mb)') +
theme_few(base_size = 6)
`summarise()` has grouped output by 'kmer_size', 'bp_training'. You can override using the `.groups` argument.
p
# Create the folder for paper figures.  showWarnings = FALSE makes re-runs
# quiet when the folder already exists (creation is idempotent).
dir.create('paper_images', showWarnings = FALSE)
Warning: 'paper_images' already exists
ggsave(filename = 'kmerlen_vs_accuracy.png',plot =p,device='png',path = 'paper_images',width = 22,height = 5,units = 'cm',dpi = 2400)
# Median accuracy per (training amount, kmer size), restricted to a subset
# of training amounts and intermediate validation amounts; used below for
# the vertical reference lines in the histograms.  .groups = 'drop'
# returns an ungrouped tibble and silences the grouping message.
means = df %>%
filter(bp_training %in% c('0.5','1','200','all')) %>%
#filter(bp_valid %in% c('50','100','200','all')) %>%
filter(bp_valid %in% c('2','5','10','20','50','100','200')) %>%
group_by(bp_training,kmer_size) %>%
summarise(Int=median(valid_acc), .groups = 'drop')
`summarise()` has grouped output by 'bp_training'. You can override using the `.groups` argument.
# Histograms of accuracy for selected training/validation amounts, with the
# per-panel median (from `means`) as a vertical line.
df %>%
filter(bp_training %in% c('0.5','1','200','all')) %>%
filter(bp_valid %in% c('2','5','10','20','50','100','200')) %>%
#filter(bp_valid %in% c('50','100','200','all')) %>%
ggplot(aes(x=valid_acc)) +
geom_histogram(aes(x=valid_acc)) +
facet_grid(kmer_size~bp_training) +
geom_vline(data = means, aes(xintercept = Int))
So it seems that the smallest kmer sizes never result in very high accuracy, and the largest kmer sizes result in high accuracy for higher amounts of data, but lower accuracy for lower amounts. It seems that a kmer size of 7 is a good balance, and that training using images of different sizes helps in being more robust to the amount of data used to produce validation images.
As little as 1Mb produces moderately accurate results for kmer size 7 or below.
Can we quantify what is different about images produced with different data amounts? It seems there is larger variation in pixel intensities, probably because of random fluctuations:
# List every image across the per-kmer-size folders (images_5 .. images_9).
# The pattern is anchored as '\\.png$' because list.files() treats it as a
# regex: the original '.png' would also match names like 'Xpng' since '.'
# is a wildcard.
images = unlist(lapply(paste0('images_', 5:9), list.files,
                       pattern = '\\.png$', recursive = TRUE, full.names = TRUE))
# Number of canonical kmers of length k.
# Palindromic kmers (possible only for even k) are their own reverse
# complement, so they must not be halved with the rest of the 4^k kmers.
# Derivation: https://bioinfologics.github.io/post/2018/09/17/k-mer-counting-part-i-introduction/
nkmers = function(k){
  total = 4^k
  palindromes = (k %% 2 == 0) * 4^(k/2)
  (total + palindromes) / 2
}
# For one image file, parse the metadata encoded in its path (kmer size,
# taxon, sample id, Mb of data) and compute the standard deviation of the
# pixel-value counts over the top nkmers(k) sorted pixel values.
# Example path: 'images_7/S_bannisterioides+S-91_00200000K.png'.
get_sd = function(path){
  kmer_len   = as.integer(gsub('.+_([0-9])/.+','\\1', path))
  taxon_name = gsub('.+/(.+)\\+.+','\\1', path)
  sample_id  = gsub('.+\\+(S-[0-9]+)_.+','\\1', path)
  # Filename stores Kb as an 8-digit field; divide by 1000 to get Mb.
  mbp_used   = as.integer(gsub('.+_([0-9]{8})K.+','\\1', path)) / 1000
  pixels = sort(png::readPNG(path))
  # Keep only the largest nkmers(k) values -- the pixels that carry the
  # kmer-frequency signal for this kmer size.
  keep = nkmers(kmer_len)
  pixels = pixels[(length(pixels) - keep + 1):length(pixels)]
  data.frame(k = kmer_len, taxon = taxon_name, sample = sample_id,
             Mbp = mbp_used, sd_counts = sd(table(pixels)))
}
# Run get_sd over every image in parallel using 4 background R sessions.
plan(multisession(workers = 4))
df = furrr::future_map_dfr(images,get_sd)
# Return to sequential execution so the workers are not kept alive.
plan(sequential)
df
# SD of pixel-value counts vs amount of data, one line per sample, one panel
# per kmer size; log-log scale since both quantities span orders of magnitude.
ggplot(df) +
geom_line(aes(x=Mbp, y=sd_counts,color=sample)) +
facet_wrap(as.factor(k)~.,scales = 'free') +
scale_color_discrete(guide='none') +
scale_y_log10() +
scale_x_log10()
Now we will check the results of using different training parameters: - model pretraining - augmentation (CutMix or MixUp) - Label Smoothing - model architecture - lighting transformations
Let’s read the data and prepare for plotting:
# Load the training-parameter sweep and derive plotting columns:
# - bp_valid as an ordered factor in Mb
# - augmentation (None / MixUp / CutMix) parsed from the callback string
# - human-readable flags for label smoothing, pretraining and transforms
# - a combined 'parameters' label, ordered by mean validation accuracy.
df = read_csv('training_params.csv')[-1] %>%
mutate(bp_valid = factor(as.character(as.integer(bp_valid)/1e6),
# Both references to bp_valid below evaluate against the original numeric
# column (they are inside a single mutate expression), so the levels sort
# numerically before being coerced to character by factor().
levels = sort(unique(bp_valid/1e6)),
ordered = TRUE),
augmentation = ifelse(str_detect(callback,'CutMix'),'CutMix',
ifelse(str_detect(callback,'MixUp'),'MixUp',
'None')
),
augmentation = factor(augmentation, levels = c('None','MixUp','CutMix'),ordered = F),
# aug drops the 'None' label so it disappears from the combined string.
aug = str_replace(augmentation,'None',''),
lablsmth= ifelse(label_smoothing,
'label Smoothing',
''),
pretr = ifelse(pretrained,
'pretrained',
''
),
transformations = ifelse(trans,
'with_transforms',
''
),
# Join the flags, then collapse repeated/leading/trailing commas left by
# empty flags, and map the fully-empty combination to 'None'.
parameters = paste(arch,pretr,lablsmth,aug,transformations,sep=',') %>%
str_replace_all(',{2,}',',') %>%
str_remove_all('^,|,$') %>%
str_replace_all('^$','None') %>%
fct_reorder(valid_acc, mean)
)
New names:Rows: 30240 Columns: 16── Column specification ───────────────────────────────────────────────────────────────────────────────────────────
Delimiter: ","
chr (5): bp_training, samples_training, samples_valid, callback, arch
dbl (8): ...1, kmer_size, replicate, bp_valid, n_samp_training, n_samp_valid, valid_loss, valid_acc
lgl (3): label_smoothing, pretrained, trans
ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
df
NA
NA
NA
Now we can plot the effect of parameters. There are clearly some models that do much better than others:
# Accuracy for every parameter combination, ordered by mean accuracy on x,
# colored by the amount of validation data.
ggplot(df, aes(x = parameters, y = valid_acc)) +
#geom_boxplot() +
#geom_violin(adjust=1.5) +
geom_jitter(aes(color = bp_valid),height = 0.005) +
scale_color_viridis_d(option='turbo',begin = 0.1, end=0.9) +
#facet_wrap(~bp_valid) +
theme(axis.text.x = element_text(hjust = 1, angle = 45))
Let’s look at the top 20 models:
# Same plot restricted to the 20 best parameter combinations (the last 20
# levels of `parameters`, which is ordered by mean accuracy).
ggplot(filter(df, parameters %in% tail(levels(df$parameters),20)), aes(x = parameters, y = valid_acc)) +
#geom_boxplot() +
#geom_violin(adjust=1.5) +
geom_jitter(aes(color = bp_valid),height = 0.005) +
scale_color_viridis_d(option='turbo',begin = 0.1, end=0.9) +
#facet_wrap(~bp_valid) +
theme(axis.text.x = element_text(hjust = 1, angle = 45))
Let’s plot by architecture:
# Validation accuracy by model architecture, ordered by mean accuracy.
# `linewidth` replaces the deprecated `size` aesthetic for line-based geoms
# (ggplot2 >= 3.4), which removes the deprecation warning.
p = ggplot(mutate(df, arch = fct_reorder(arch,valid_acc)),
aes(x = arch, y = valid_acc, color=bp_valid)) +
#geom_boxplot() +
#geom_violin(adjust=1.5) +
geom_jitter(aes(color = bp_valid),height = 0.005, size = 0.1, alpha = 0.1, shape = 16) +
# Crossbar marks the mean accuracy per architecture.
stat_summary(fun = mean, geom = 'crossbar', linewidth = 0.05, show.legend=FALSE) +
scale_color_viridis_d(option='turbo',begin = 0.1, end=0.9, name = 'Mbp in validation\nimages',
guide = guide_legend(override.aes = list(alpha = 1))) +
scale_y_continuous(labels = scales::percent, name = 'Validation Accuracy') +
xlab('Model architecture') +
#facet_wrap(~bp_valid) +
theme_few(base_size = 6) +
theme(axis.text.x = element_text(hjust = 1, angle = 45),
legend.key.size = unit(0.2, "cm"))
Warning: Using `size` aesthetic for lines was deprecated in ggplot2 3.4.0.
Please use `linewidth` instead.
p
# Export the per-architecture figure.
ggsave(filename = 'architecture.png',plot =p,device='png',path = 'paper_images',width = 5,height = 5,units = 'cm',dpi = 2400)
Now by pretrained:
# Validation accuracy by pretraining status.  `linewidth` replaces the
# deprecated `size` aesthetic in stat_summary (ggplot2 >= 3.4).
# NOTE(review): the x labels c('pre-trained','random') are applied by level
# position after fct_reorder -- confirm pretrained models really rank first.
p = ggplot(mutate(df, pretr = fct_reorder(pretr,valid_acc)),
aes(x = pretr, y = valid_acc, color=bp_valid)) +
#geom_boxplot() +
#geom_violin() +
geom_jitter(aes(color = bp_valid),height = 0.005, size = 0.05, alpha = 0.1, shape = 16) +
stat_summary(fun = mean, geom = 'crossbar', linewidth = 0.05) +
scale_x_discrete(labels = c('pre-trained','random'), name = 'Model pretraining') +
scale_color_viridis_d(option='turbo',begin = 0.1, end=0.9, name = 'Mbp in validation\nimages') +
scale_y_continuous(labels = scales::percent, name = 'Validation Accuracy') +
#facet_wrap(~bp_valid) +
theme_few(base_size = 6) +
theme(axis.text.x = element_text(hjust = 1, angle = 45),
legend.position = 'none'
)
p
# Export the pretraining figure.
ggsave(filename = 'pretraining.png',plot =p,device='png',path = 'paper_images',width = 3,height = 5,units = 'cm',dpi = 2400)
Now by label smoothing:
# Validation accuracy by label smoothing.  `linewidth` replaces the
# deprecated `size` aesthetic in stat_summary (ggplot2 >= 3.4).
p = ggplot(mutate(df, lablsmth = fct_reorder(lablsmth,valid_acc)),
aes(x = lablsmth, y = valid_acc, color=bp_valid)) +
#geom_boxplot() +
#geom_violin(adjust=1.5) +
geom_jitter(aes(color = bp_valid),height = 0.005, size = 0.05, alpha = 0.1, shape = 16) +
stat_summary(fun = mean, geom = 'crossbar', linewidth = 0.05) +
scale_x_discrete(labels = c('No','Yes'), name = 'Label smoothing') +
scale_color_viridis_d(option='turbo',begin = 0.1, end=0.9, name = 'Mbp in validation\nimages') +
scale_y_continuous(labels = scales::percent, name = 'Validation Accuracy') +
#facet_wrap(~bp_valid) +
theme_few(base_size = 6) +
theme(axis.text.x = element_text(hjust = 1, angle = 45),
legend.position = 'none'
)
p
# Export the label-smoothing figure.
ggsave(filename = 'labelsmoothing.png',plot =p,device='png',path = 'paper_images',width = 3,height = 5,units = 'cm',dpi = 2400)
Now by CutMix/MixUp augmentations:
# Validation accuracy by CutMix/MixUp augmentation.  `linewidth` replaces
# the deprecated `size` aesthetic in stat_summary (ggplot2 >= 3.4).
p = ggplot(mutate(df, augmentation = fct_reorder(augmentation,valid_acc)),
aes(x = augmentation, y = valid_acc, color = bp_valid)) +
#geom_boxplot() +
#geom_violin(adjust=1.5) +
geom_jitter(aes(color = bp_valid),height = 0.005, size = 0.05, alpha = 0.1, shape = 16) +
stat_summary(fun = mean, geom = 'crossbar', linewidth = 0.05) +
scale_x_discrete(name = 'Augmentation') +
scale_color_viridis_d(option='turbo',begin = 0.1, end=0.9, name = 'Mbp in validation\nimages') +
scale_y_continuous(labels = scales::percent, name = 'Validation Accuracy') +
theme_few(base_size = 6) +
theme(axis.text.x = element_text(hjust = 1, angle = 45),
legend.position = 'none'
)
p
# Export the augmentation figure.
ggsave(filename = 'augmentations.png',plot =p,device='png',path = 'paper_images',width = 3,height = 5,units = 'cm',dpi = 2400)
Finally, by lighting transforms:
# Validation accuracy by lighting transforms.  `linewidth` replaces the
# deprecated `size` aesthetic in stat_summary (ggplot2 >= 3.4).
# NOTE(review): x labels c('No','Yes') are positional after fct_reorder --
# confirm the untransformed level really ranks first.
p = ggplot(mutate(df, transformations = fct_reorder(transformations,valid_acc)),
aes(x = transformations, y = valid_acc, color=bp_valid)) +
#geom_boxplot() +
#geom_violin(adjust=1.5) +
geom_jitter(aes(color = bp_valid),height = 0.005, size = 0.05, alpha = 0.1, shape = 16) +
stat_summary(fun = mean, geom = 'crossbar', linewidth = 0.05) +
scale_x_discrete(name = 'Lighting transforms', labels = c('No','Yes')) +
scale_color_viridis_d(option='turbo',begin = 0.1, end=0.9, name = 'Mbp in validation\nimages') +
scale_y_continuous(labels = scales::percent, name = 'Validation Accuracy') +
theme_few(base_size = 6) +
theme(axis.text.x = element_text(hjust = 1, angle = 45),
legend.position = 'none'
)
p
# Export the lighting-transforms figure.
ggsave(filename = 'lighting.png',plot =p,device='png',path = 'paper_images',width = 3,height = 5,units = 'cm',dpi = 2400)
Let’s try a linear model to check which combination is best:
# Full factorial linear model on arcsine-transformed accuracy (a variance-
# stabilizing transform for proportions).  With six fully crossed factors
# this fits a very large coefficient set; it serves mainly as the upper
# scope for the stepwise search below.
# NOTE(review): asin() assumes valid_acc stays within [0,1] -- confirm.
full_model = lm(asin(valid_acc)~arch*trans*pretrained*augmentation*label_smoothing*bp_valid, data = df)
# Standard lm diagnostics (residuals vs fitted, QQ, scale-location, leverage).
plot(full_model)
# Forward stepwise AIC selection from the intercept-only model up to the
# full factorial model, to find a parsimonious set of effects/interactions.
reduced_model = step(lm(asin(valid_acc)~1, data = df),
scope = list(lower = formula(asin(valid_acc)~1),
upper = formula(asin(valid_acc)~arch*trans*pretrained*augmentation*label_smoothing*bp_valid)
),
direction = 'forward')
The best model is quite complex with some interactions
reduced_model
Call:
lm(formula = asin(valid_acc) ~ bp_valid + pretrained + arch +
augmentation + label_smoothing + trans + pretrained:arch +
bp_valid:pretrained + bp_valid:arch + arch:augmentation +
bp_valid:augmentation + pretrained:augmentation + augmentation:label_smoothing +
pretrained:label_smoothing + arch:trans + augmentation:trans +
bp_valid:pretrained:arch + pretrained:arch:augmentation +
pretrained:augmentation:label_smoothing + arch:augmentation:trans +
bp_valid:pretrained:augmentation, data = df)
Coefficients:
(Intercept)
6.561e-01
bp_valid.L
3.114e-01
bp_valid.Q
-1.833e-01
bp_valid.C
5.212e-02
bp_valid^4
1.349e-02
bp_valid^5
-2.865e-02
bp_valid^6
1.188e-02
bp_valid^7
9.771e-03
bp_valid^8
-5.580e-03
pretrainedTRUE
6.445e-02
archig_resnext101_32x8d
4.786e-01
archresnet101d
3.779e-01
archresnet18d
3.050e-01
archresnet50
4.784e-01
archresnet50d
3.871e-01
archwide_resnet50_2
5.112e-01
augmentationMixUp
6.400e-02
augmentationCutMix
-2.131e-02
label_smoothingTRUE
3.604e-02
transTRUE
1.098e-02
pretrainedTRUE:archig_resnext101_32x8d
-3.675e-01
pretrainedTRUE:archresnet101d
-3.555e-01
pretrainedTRUE:archresnet18d
-2.542e-01
pretrainedTRUE:archresnet50
-3.452e-01
pretrainedTRUE:archresnet50d
-3.491e-01
pretrainedTRUE:archwide_resnet50_2
-4.532e-01
bp_valid.L:pretrainedTRUE
2.358e-01
bp_valid.Q:pretrainedTRUE
-4.133e-02
bp_valid.C:pretrainedTRUE
-5.603e-02
bp_valid^4:pretrainedTRUE
4.519e-02
bp_valid^5:pretrainedTRUE
-1.260e-02
bp_valid^6:pretrainedTRUE
-1.502e-02
bp_valid^7:pretrainedTRUE
5.130e-03
bp_valid^8:pretrainedTRUE
3.372e-03
bp_valid.L:archig_resnext101_32x8d
-2.926e-02
bp_valid.Q:archig_resnext101_32x8d
-8.490e-03
bp_valid.C:archig_resnext101_32x8d
1.977e-02
bp_valid^4:archig_resnext101_32x8d
-8.721e-03
bp_valid^5:archig_resnext101_32x8d
-2.483e-03
bp_valid^6:archig_resnext101_32x8d
8.711e-03
bp_valid^7:archig_resnext101_32x8d
-1.260e-02
bp_valid^8:archig_resnext101_32x8d
2.109e-04
bp_valid.L:archresnet101d
-1.653e-01
bp_valid.Q:archresnet101d
7.381e-02
bp_valid.C:archresnet101d
9.417e-04
bp_valid^4:archresnet101d
-2.948e-02
bp_valid^5:archresnet101d
2.877e-02
bp_valid^6:archresnet101d
-1.376e-02
bp_valid^7:archresnet101d
-5.532e-03
bp_valid^8:archresnet101d
-4.737e-03
bp_valid.L:archresnet18d
-1.193e-01
bp_valid.Q:archresnet18d
5.106e-02
bp_valid.C:archresnet18d
1.054e-02
bp_valid^4:archresnet18d
-3.777e-02
bp_valid^5:archresnet18d
2.669e-02
bp_valid^6:archresnet18d
-5.637e-03
bp_valid^7:archresnet18d
-1.105e-02
bp_valid^8:archresnet18d
1.530e-03
bp_valid.L:archresnet50
-4.650e-02
bp_valid.Q:archresnet50
-6.761e-03
bp_valid.C:archresnet50
2.939e-02
bp_valid^4:archresnet50
-1.920e-02
bp_valid^5:archresnet50
2.184e-03
bp_valid^6:archresnet50
1.335e-02
bp_valid^7:archresnet50
-1.706e-02
bp_valid^8:archresnet50
2.617e-03
bp_valid.L:archresnet50d
-1.553e-01
bp_valid.Q:archresnet50d
7.842e-02
bp_valid.C:archresnet50d
-4.425e-03
bp_valid^4:archresnet50d
-2.616e-02
bp_valid^5:archresnet50d
2.930e-02
bp_valid^6:archresnet50d
-1.431e-02
bp_valid^7:archresnet50d
-3.372e-03
bp_valid^8:archresnet50d
-5.996e-03
bp_valid.L:archwide_resnet50_2
-3.658e-02
bp_valid.Q:archwide_resnet50_2
1.171e-03
bp_valid.C:archwide_resnet50_2
1.877e-02
bp_valid^4:archwide_resnet50_2
-1.304e-02
bp_valid^5:archwide_resnet50_2
-1.398e-02
bp_valid^6:archwide_resnet50_2
2.189e-02
bp_valid^7:archwide_resnet50_2
-1.898e-02
bp_valid^8:archwide_resnet50_2
-9.812e-04
archig_resnext101_32x8d:augmentationMixUp
-1.766e-02
archresnet101d:augmentationMixUp
5.130e-02
archresnet18d:augmentationMixUp
-9.741e-03
archresnet50:augmentationMixUp
-2.336e-02
archresnet50d:augmentationMixUp
2.380e-02
archwide_resnet50_2:augmentationMixUp
-5.815e-02
archig_resnext101_32x8d:augmentationCutMix
9.295e-02
archresnet101d:augmentationCutMix
1.288e-01
archresnet18d:augmentationCutMix
6.007e-02
archresnet50:augmentationCutMix
7.041e-02
archresnet50d:augmentationCutMix
1.017e-01
archwide_resnet50_2:augmentationCutMix
3.642e-02
bp_valid.L:augmentationMixUp
3.908e-03
bp_valid.Q:augmentationMixUp
-6.958e-03
bp_valid.C:augmentationMixUp
4.087e-03
bp_valid^4:augmentationMixUp
2.047e-04
bp_valid^5:augmentationMixUp
5.062e-03
bp_valid^6:augmentationMixUp
-4.185e-03
bp_valid^7:augmentationMixUp
2.057e-03
bp_valid^8:augmentationMixUp
4.019e-03
bp_valid.L:augmentationCutMix
-1.606e-02
bp_valid.Q:augmentationCutMix
4.800e-03
bp_valid.C:augmentationCutMix
7.649e-05
bp_valid^4:augmentationCutMix
-1.092e-04
bp_valid^5:augmentationCutMix
1.038e-02
bp_valid^6:augmentationCutMix
-7.533e-03
bp_valid^7:augmentationCutMix
2.572e-03
bp_valid^8:augmentationCutMix
4.894e-03
pretrainedTRUE:augmentationMixUp
-5.293e-02
pretrainedTRUE:augmentationCutMix
2.794e-02
augmentationMixUp:label_smoothingTRUE
-5.032e-02
augmentationCutMix:label_smoothingTRUE
-5.159e-02
pretrainedTRUE:label_smoothingTRUE
-5.352e-02
archig_resnext101_32x8d:transTRUE
1.860e-03
archresnet101d:transTRUE
1.813e-03
archresnet18d:transTRUE
1.203e-03
archresnet50:transTRUE
-9.509e-03
archresnet50d:transTRUE
-4.479e-03
archwide_resnet50_2:transTRUE
-1.943e-02
augmentationMixUp:transTRUE
-1.450e-02
augmentationCutMix:transTRUE
-5.982e-03
bp_valid.L:pretrainedTRUE:archig_resnext101_32x8d
-2.637e-02
bp_valid.Q:pretrainedTRUE:archig_resnext101_32x8d
-2.021e-02
bp_valid.C:pretrainedTRUE:archig_resnext101_32x8d
1.837e-02
bp_valid^4:pretrainedTRUE:archig_resnext101_32x8d
-1.841e-02
bp_valid^5:pretrainedTRUE:archig_resnext101_32x8d
2.453e-02
bp_valid^6:pretrainedTRUE:archig_resnext101_32x8d
-9.192e-03
bp_valid^7:pretrainedTRUE:archig_resnext101_32x8d
-3.420e-03
bp_valid^8:pretrainedTRUE:archig_resnext101_32x8d
-1.808e-02
bp_valid.L:pretrainedTRUE:archresnet101d
1.349e-01
bp_valid.Q:pretrainedTRUE:archresnet101d
-7.756e-02
bp_valid.C:pretrainedTRUE:archresnet101d
3.039e-02
bp_valid^4:pretrainedTRUE:archresnet101d
6.843e-03
bp_valid^5:pretrainedTRUE:archresnet101d
-1.434e-02
bp_valid^6:pretrainedTRUE:archresnet101d
-1.204e-03
bp_valid^7:pretrainedTRUE:archresnet101d
-6.390e-03
bp_valid^8:pretrainedTRUE:archresnet101d
1.045e-02
bp_valid.L:pretrainedTRUE:archresnet18d
1.673e-01
bp_valid.Q:pretrainedTRUE:archresnet18d
-7.756e-02
bp_valid.C:pretrainedTRUE:archresnet18d
2.741e-02
bp_valid^4:pretrainedTRUE:archresnet18d
7.421e-04
bp_valid^5:pretrainedTRUE:archresnet18d
-6.923e-04
bp_valid^6:pretrainedTRUE:archresnet18d
-1.453e-02
bp_valid^7:pretrainedTRUE:archresnet18d
8.352e-03
bp_valid^8:pretrainedTRUE:archresnet18d
-4.905e-03
bp_valid.L:pretrainedTRUE:archresnet50
-1.594e-02
bp_valid.Q:pretrainedTRUE:archresnet50
9.914e-03
bp_valid.C:pretrainedTRUE:archresnet50
-1.984e-02
bp_valid^4:pretrainedTRUE:archresnet50
-1.313e-02
bp_valid^5:pretrainedTRUE:archresnet50
3.741e-02
bp_valid^6:pretrainedTRUE:archresnet50
5.299e-04
bp_valid^7:pretrainedTRUE:archresnet50
-3.754e-03
bp_valid^8:pretrainedTRUE:archresnet50
-1.681e-03
bp_valid.L:pretrainedTRUE:archresnet50d
1.205e-01
bp_valid.Q:pretrainedTRUE:archresnet50d
-1.114e-01
bp_valid.C:pretrainedTRUE:archresnet50d
4.579e-02
bp_valid^4:pretrainedTRUE:archresnet50d
9.102e-03
bp_valid^5:pretrainedTRUE:archresnet50d
-3.293e-02
bp_valid^6:pretrainedTRUE:archresnet50d
9.293e-03
bp_valid^7:pretrainedTRUE:archresnet50d
2.861e-02
bp_valid^8:pretrainedTRUE:archresnet50d
8.449e-03
bp_valid.L:pretrainedTRUE:archwide_resnet50_2
-2.731e-02
bp_valid.Q:pretrainedTRUE:archwide_resnet50_2
2.152e-02
bp_valid.C:pretrainedTRUE:archwide_resnet50_2
-1.329e-02
bp_valid^4:pretrainedTRUE:archwide_resnet50_2
-2.174e-02
bp_valid^5:pretrainedTRUE:archwide_resnet50_2
5.888e-02
bp_valid^6:pretrainedTRUE:archwide_resnet50_2
-4.246e-02
bp_valid^7:pretrainedTRUE:archwide_resnet50_2
3.718e-02
bp_valid^8:pretrainedTRUE:archwide_resnet50_2
-3.049e-02
pretrainedTRUE:archig_resnext101_32x8d:augmentationMixUp
6.789e-02
pretrainedTRUE:archresnet101d:augmentationMixUp
-3.619e-02
pretrainedTRUE:archresnet18d:augmentationMixUp
-3.443e-03
pretrainedTRUE:archresnet50:augmentationMixUp
2.548e-02
pretrainedTRUE:archresnet50d:augmentationMixUp
5.615e-04
pretrainedTRUE:archwide_resnet50_2:augmentationMixUp
4.066e-02
pretrainedTRUE:archig_resnext101_32x8d:augmentationCutMix
-9.468e-02
pretrainedTRUE:archresnet101d:augmentationCutMix
-1.122e-01
pretrainedTRUE:archresnet18d:augmentationCutMix
-8.547e-02
pretrainedTRUE:archresnet50:augmentationCutMix
-8.759e-02
pretrainedTRUE:archresnet50d:augmentationCutMix
-6.982e-02
pretrainedTRUE:archwide_resnet50_2:augmentationCutMix
-6.689e-02
pretrainedTRUE:augmentationMixUp:label_smoothingTRUE
5.111e-02
pretrainedTRUE:augmentationCutMix:label_smoothingTRUE
5.191e-02
archig_resnext101_32x8d:augmentationMixUp:transTRUE
-5.656e-03
archresnet101d:augmentationMixUp:transTRUE
2.068e-02
archresnet18d:augmentationMixUp:transTRUE
1.244e-02
archresnet50:augmentationMixUp:transTRUE
9.801e-03
archresnet50d:augmentationMixUp:transTRUE
-9.574e-04
archwide_resnet50_2:augmentationMixUp:transTRUE
1.794e-02
archig_resnext101_32x8d:augmentationCutMix:transTRUE
8.469e-03
archresnet101d:augmentationCutMix:transTRUE
-8.739e-03
archresnet18d:augmentationCutMix:transTRUE
-2.055e-02
archresnet50:augmentationCutMix:transTRUE
1.405e-02
archresnet50d:augmentationCutMix:transTRUE
-1.698e-04
archwide_resnet50_2:augmentationCutMix:transTRUE
2.349e-03
bp_valid.L:pretrainedTRUE:augmentationMixUp
-5.862e-03
bp_valid.Q:pretrainedTRUE:augmentationMixUp
-9.938e-03
bp_valid.C:pretrainedTRUE:augmentationMixUp
1.798e-02
bp_valid^4:pretrainedTRUE:augmentationMixUp
5.561e-04
bp_valid^5:pretrainedTRUE:augmentationMixUp
-7.104e-03
bp_valid^6:pretrainedTRUE:augmentationMixUp
2.117e-03
bp_valid^7:pretrainedTRUE:augmentationMixUp
-4.217e-03
bp_valid^8:pretrainedTRUE:augmentationMixUp
1.170e-03
bp_valid.L:pretrainedTRUE:augmentationCutMix
-4.333e-02
bp_valid.Q:pretrainedTRUE:augmentationCutMix
1.675e-03
bp_valid.C:pretrainedTRUE:augmentationCutMix
2.397e-02
bp_valid^4:pretrainedTRUE:augmentationCutMix
-7.490e-03
bp_valid^5:pretrainedTRUE:augmentationCutMix
-1.199e-02
bp_valid^6:pretrainedTRUE:augmentationCutMix
1.557e-02
bp_valid^7:pretrainedTRUE:augmentationCutMix
-2.549e-03
bp_valid^8:pretrainedTRUE:augmentationCutMix
-3.646e-03
Let’s now look at model predictions to get a better sense. We can see a few things:
ig_ stands for pre-trained on instagram,
but we see that pretraining did not help. For more information about the
architecture, see: https://arxiv.org/pdf/1611.05431.pdf

predictions = select(df,trans,arch,pretrained,label_smoothing,augmentation,bp_valid) %>%
distinct()
# Predict accuracy for each unique parameter combination; sin() undoes the
# asin() transform used when fitting reduced_model.
predictions$predicted_acc = sin(predict(reduced_model, predictions))
predictions = predictions %>%
arrange(-predicted_acc)
# One ranked table per validation-data amount.
predictions %>%
split(.$bp_valid)
$`0.5`
$`1`
$`2`
$`5`
$`10`
$`20`
$`50`
$`100`
$`200`
NA
Now that we optimized training parameters, let’s evaluate the effect of sample quality. To do that, we did training using only 5 randomly chosen samples as training set, including 0-3 of the four lowest-quality samples per species. Quality was evaluated using two metrics: insert size or increase in T content throughout read length. We then evaluated, for each of the 5 samples per species left out of the training set, whether its prediction was correct.
We did 50 replicates randomly choosing the training set for each combination of quality metric and number of low-quality samples in the training set. Let’s now evaluate the results. Let’s start by reading the data.
df = read_csv('sample_quality.csv')[-1]
New names:Rows: 93418 Columns: 13── Column specification ───────────────────────────────────────────────────────────────────────────────────────────
Delimiter: ","
chr (6): bp_training, samples_training, qual_metric, sample_valid, valid_actual, valid_prediction
dbl (6): ...1, kmer_size, replicate, bp_valid, n_samp_training, n_lowqual_training
lgl (1): valid_lowqual
ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Flag whether each validation sample was assigned to its true species.
df = df %>%
mutate(correct_pred = valid_actual == valid_prediction)
df
It seems that in general including some low quality samples (by the variation in content metric) may improve high-quality samples a little bit, but only increases variation of low quality samples instead of clearly improving them.
# Histograms of mean accuracy per replicate, split by number of low-quality
# samples in training (rows) and whether the validation sample itself is
# low quality (columns), for the base-content-SD quality metric.
p = df %>%
filter(qual_metric == 'high_c_sd') %>%
group_by(replicate, sample_valid, n_lowqual_training) %>%
# Keep only the smallest validation-data amount for each sample.
filter(bp_valid == min(bp_valid)) %>%
group_by(replicate, n_lowqual_training, valid_lowqual) %>%
summarize(mean_acc = mean(correct_pred)) %>%
# Recode the logical flag into readable facet labels via named lookup.
mutate(valid_lowqual = c('TRUE' = 'Validation accuracy for low quality samples', 'FALSE' = 'Validation accuracy for high-quality samples')[as.character(valid_lowqual)]) %>%
ggplot() +
geom_histogram(aes(x = mean_acc), boundary = 1) +
# Secondary axis used only as a right-hand facet caption.
# NOTE(review): sec_axis('identity', ...) passes a string where newer
# ggplot2 expects a transform/formula (~.) -- confirm on current versions.
scale_y_continuous(sec.axis = sec_axis('identity', name = 'Number of low quality samples in training set',breaks = NULL, labels = NULL, guide = NULL)) +
scale_x_continuous(limits = c(0,1)) +
xlab('Average validation accuracy across all samples') +
ylab('Frequency across replicates') +
labs(title = 'Effect of quality determined by variation in GC content on accuracy') +
facet_grid(n_lowqual_training~valid_lowqual) +
theme_few() +
theme(strip.background = element_rect(fill=gray(0.8)),
plot.title = element_text(hjust = 0.5)
)
`summarise()` has grouped output by 'replicate', 'n_lowqual_training'. You can override using the `.groups` argument.
p
# Export in both vector (pdf) and raster (png) formats.
ggsave(filename = 'quality_content.pdf',plot =p,device='pdf',path = 'paper_images',width = 7,height = 5,units = 'in')
ggsave(filename = 'quality_content.png',plot =p,device='png',path = 'paper_images',width = 7,height = 5,units = 'in')
NA
NA
NA
NA
The effect is less pronounced for average insert size
# Same histogram layout as above, but for the insert-size quality metric.
# Note this keeps the LARGEST validation-data amount per sample (max), while
# the content-SD plot above keeps the smallest (min) -- presumably a
# deliberate choice; worth confirming.
p = df %>%
filter(qual_metric == 'low_size') %>%
group_by(replicate, sample_valid, n_lowqual_training) %>%
filter(bp_valid == max(bp_valid)) %>%
group_by(replicate, n_lowqual_training, valid_lowqual) %>%
summarize(mean_acc = mean(correct_pred)) %>%
mutate(valid_lowqual = c('TRUE' = 'Validation accuracy for low quality samples', 'FALSE' = 'Validation accuracy for high-quality samples')[as.character(valid_lowqual)]) %>%
ggplot() +
geom_histogram(aes(x = mean_acc), boundary = 1) +
scale_y_continuous(sec.axis = sec_axis('identity', name = 'Number of low quality samples in training set',breaks = NULL, labels = NULL, guide = NULL)) +
scale_x_continuous(limits = c(0,1)) +
xlab('Average validation accuracy across all samples') +
ylab('Frequency across replicates') +
labs(title = 'Effect of sequencing quality determined by insert size on accuracy') +
facet_grid(n_lowqual_training~valid_lowqual) +
theme_few() +
theme(strip.background = element_rect(fill=gray(0.8)),
plot.title = element_text(hjust = 0.5)
)
`summarise()` has grouped output by 'replicate', 'n_lowqual_training'. You can override using the `.groups` argument.
p
# Export in both vector (pdf) and raster (png) formats.
ggsave(filename = 'quality_size.pdf',plot =p,device='pdf',path = 'paper_images',width = 7,height = 5,units = 'in')
ggsave(filename = 'quality_size.png',plot =p,device='png',path = 'paper_images',width = 7,height = 5,units = 'in')
What if we order all samples by their validation accuracy and compare to the quality metrics, what do we see?
df_info = read_csv('sample_info_stats.csv')[-1]
New names:Rows: 100 Columns: 11── Column specification ───────────────────────────────────────────────────────────────────────────────────────────
Delimiter: ","
chr (7): species, collector, collection, country, dna_concentration, library_id, filename_root
dbl (4): ...1, sample_number, insert_size, content_sd
ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
df_info
There seems to be a weak negative correlation between variation in content and accuracy, but many samples that seem to be good with this metric have always low accuracy.
# Accuracy per sample vs its base-content SD, colored by species, one panel
# per number of low-quality training samples.
df %>%
filter(qual_metric == 'high_c_sd') %>%
group_by(replicate, sample_valid, n_lowqual_training) %>%
filter(bp_valid == max(bp_valid)) %>%
group_by(sample_valid, n_lowqual_training) %>%
summarize(mean_acc = mean(correct_pred)) %>%
# Attach quality metrics by library id.
left_join(df_info,by = c('sample_valid' = 'library_id')) %>%
ggplot() +
scale_x_sqrt() +
# Small vertical jitter only, so x positions stay truthful.
geom_jitter(aes(x = content_sd, y = mean_acc, color = species),width = 0, height = 0.05) +
scale_color_viridis_d(option = 'turbo') +
facet_wrap(~n_lowqual_training)
`summarise()` has grouped output by 'sample_valid'. You can override using the `.groups` argument.
Again, this is less pronounced for insert size
# Same scatter as above, but against insert size (and keeping the smallest
# bp_valid per sample rather than the largest -- confirm intentional).
df %>%
filter(qual_metric == 'low_size') %>%
group_by(replicate, sample_valid, n_lowqual_training) %>%
filter(bp_valid == min(bp_valid)) %>%
group_by(sample_valid, n_lowqual_training) %>%
summarize(mean_acc = mean(correct_pred)) %>%
left_join(df_info,by = c('sample_valid' = 'library_id')) %>%
ggplot() +
geom_jitter(aes(x = insert_size, y = mean_acc, color = species),width = 0, height = 0.05) +
scale_color_viridis_d(option = 'turbo') +
facet_wrap(~n_lowqual_training)
`summarise()` has grouped output by 'sample_valid'. You can override using the `.groups` argument.
What is the relationship between DNA extraction yield and library quality?
First, let’s plot against standard deviation.
# DNA yield vs base-content SD.  Out-of-range concentrations are mapped to
# sentinel values so they can sit on a log axis: 'too high' -> 200 and
# 0 (below detection) -> 0.05, with matching axis labels.
p1 = df_info %>%
mutate(dna_c = ifelse(dna_concentration == 'too high', 200, dna_concentration),
dna_c = as.numeric(dna_c),
dna_c = ifelse(dna_c == 0, 0.05, dna_c)) %>%
ggplot() +
geom_point(aes(dna_c, content_sd)) +
scale_y_log10(name = 'Standard deviation in base content') +
scale_x_log10(name = 'DNA yield (ng/uL)', breaks = c(0.05,0.1,1,10,100,200), labels = c('too\nlow', 0.1, 1, 10, 100, 'too\nhigh')) +
theme_few()
p1
Now, against insert size
# DNA yield vs insert size, with the same sentinel recoding of
# out-of-range concentrations as in p1 above.
p2 = df_info %>%
mutate(dna_c = ifelse(dna_concentration == 'too high', 200, dna_concentration),
dna_c = as.numeric(dna_c),
dna_c = ifelse(dna_c == 0, 0.05, dna_c)) %>%
ggplot() +
geom_point(aes(dna_c, insert_size)) +
scale_y_continuous(name = 'Insert size (bp)') +
scale_x_log10(name = 'DNA yield (ng/uL)', breaks = c(0.05,0.1,1,10,100,200), labels = c('too\nlow', 0.1, 1, 10, 100, 'too\nhigh')) +
theme_few()
p2
# Empty plot carrying only a centered title, stacked above p1/p2 with
# cowplot to give the combined figure a shared heading.
title_plot <- ggplot() +
labs(title = "Correlation between DNA yield and quality metrics") +
theme_void() + # Remove axes, legend, etc.
theme(plot.title = element_text(hjust = 0.5, size = 12, face = "bold",vjust=1),
plot.background = element_rect(fill="white",color="white")) # Center the title
p = cowplot::plot_grid(
title_plot,
cowplot::plot_grid(p1,p2,labels = "AUTO",ncol=1),
ncol = 1,
rel_heights = c(0.05,0.95) # Adjust the relative heights as needed
)
p
# Export in both pdf and png.
ggsave(filename = 'yield_vs_quality.pdf',plot =p,device='pdf',path = 'paper_images',width = 5,height = 8.5,units = 'in')
ggsave(filename = 'yield_vs_quality.png',plot =p,device='png',path = 'paper_images',width = 5,height = 8.5,units = 'in')
Bottomline: as long as the majority of the samples for each species are high-quality, having low-quality samples in the training set should not cause much trouble and might even improve inference for some low-quality samples.
Now let’s evaluate the effect of number of samples per species.
df = read_csv('n_training.csv')[-1]
New names:Rows: 164229 Columns: 10── Column specification ───────────────────────────────────────────────────────────────────────────────────────────
Delimiter: ","
chr (5): bp_training, samples_training, sample_valid, valid_actual, valid_prediction
dbl (5): ...1, kmer_size, replicate, bp_valid, n_samp_training
ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Flag whether each validation sample was assigned to its true species.
df = df %>%
mutate(correct_pred = valid_actual == valid_prediction)
df
Does the number of samples used in training impact the validation accuracy? Let’s plot one panel for each sample. It seems it does.
# Accuracy vs number of training samples per species, one panel per
# validation sample.  n_samp_training counts samples across 10 species,
# so /10 gives samples per species.
p = df %>%
group_by(n_samp_training, bp_valid, sample_valid, valid_actual) %>%
summarize(mean_acc = mean(correct_pred)) %>%
ggplot() +
#geom_jitter(aes(x = n_samp_training/10, y = mean_acc)) +
geom_boxplot(aes(x = n_samp_training/10, y = mean_acc, group = n_samp_training/10)) +
facet_wrap(valid_actual~sample_valid) +
theme_few()
`summarise()` has grouped output by 'n_samp_training', 'bp_valid', 'sample_valid'. You can override using the `.groups` argument.
p
Let’s now plot only the average accuracy for each sample across replicates, with each sample represented by a line.
It seems that having more samples in the training set does help, but in most cases about 4 samples is already pretty good. Let’s plot again, coloring by species.
# Average accuracy per sample across replicates: one line per validation
# sample, distinguished by species via color plus linetype (the 5-color
# 'Dark' palette is recycled over two dash patterns).
df %>%
  group_by(n_samp_training, sample_valid, valid_actual) %>%
  summarize(mean_acc = mean(correct_pred)) %>%
  mutate(valid_actual = fct_reorder(valid_actual, mean_acc)) %>%
  ggplot() +
  geom_line(aes(x = n_samp_training / 10, y = mean_acc,
                group = sample_valid,
                color = valid_actual, linetype = valid_actual)) +
  scale_color_manual(values = rep(few_pal('Dark')(5), times = 2)) +
  scale_linetype_manual(values = rep(1:2, each = 5)) +
  theme_few()
`summarise()` has grouped output by 'n_samp_training', 'sample_valid'. You can override using the `.groups` argument.
Now let’s try to use line type by sample quality instead.
# Per-sample mean accuracy joined with a high/low DNA-yield flag derived
# from df_info (samples at or above the median concentration count as
# high quality; 'too high' readings are stood in for by 150 ng/uL).
df_plot <- df %>%
  group_by(n_samp_training, sample_valid, valid_actual) %>%
  summarize(mean_acc = mean(correct_pred)) %>%
  mutate(valid_actual = fct_reorder(valid_actual, mean_acc)) %>%
  left_join(df_info %>%
              mutate(sample_valid = paste0('S-', sample_number),
                     dna_concentration = as.numeric(
                       ifelse(dna_concentration == 'too high',
                              150, dna_concentration)),
                     highqual = dna_concentration >=
                       quantile(dna_concentration, probs = 0.5)) %>%
              select(sample_valid, highqual))
`summarise()` has grouped output by 'n_samp_training', 'sample_valid'. You can override using the `.groups` argument.Joining with `by = join_by(sample_valid)`
# Median and interquartile band of per-sample accuracy at each
# training-set size.
df_ribbon <- df_plot %>%
  group_by(n_samp_training) %>%
  summarise(q1 = quantile(mean_acc, probs = 0.25),
            median = median(mean_acc),
            q3 = quantile(mean_acc, probs = 0.75))
# Per-sample accuracy trajectories vs. training samples per species,
# with an interquartile ribbon and a red median line across samples.
# Fix: ggplot2 >= 3.4 (3.4.3 is loaded here) deprecates `size` for line
# geoms in favor of `linewidth`; using `size` emits deprecation warnings.
p = ggplot(df_plot) +
  # interquartile band across validation samples
  stat_summary(aes(x = n_samp_training/10, y = mean_acc), fill = 'pink',
               fun.max = function(x){quantile(x,0.75)},
               fun.min = function(x){quantile(x,0.25)}, geom='ribbon') +
  # one faint line per validation sample; linetype encodes DNA yield
  geom_line(aes(x = n_samp_training/10, y = mean_acc, group = sample_valid,
                linetype = highqual), alpha = 0.5, linewidth = 0.25) +
  # median accuracy across samples
  stat_summary(aes(x = n_samp_training/10, y = mean_acc), color = 'red',
               linewidth = 0.5, fun = 'median', geom='line') +
  scale_linetype_manual(values = c('TRUE' = "solid", 'FALSE' = "51"),
                        name = 'DNA yield',
                        labels = c('TRUE' = 'High', 'FALSE' = 'Low')) +
  scale_x_continuous(breaks=1:7) +
  ylab('Average validation accuracy') +
  xlab('Training samples per species') +
  theme_few(base_size = 6) +
  theme(legend.key.size = unit(0.2, "cm"))
p
The graph is a little cluttered, let’s now do a version for the final figure in the paper:
# Join per-sample accuracy with the varKode noisiness metric
# (content_sd) and rank DNA quality into percentiles: higher content_sd
# means a noisier varKode, so 1 - content_sd orders samples from worst
# to best quality. Species facets are ordered by mean accuracy.
# Fix: spell out TRUE instead of the reassignable shorthand T.
df_facet_plot = df_plot %>%
  ungroup() %>%
  left_join(read_csv('sample_info_stats.csv') %>%
              select(sample_valid = library_id, content_sd)) %>%
  mutate(dna_quality = ntile(1 - content_sd, 100)) %>%
  mutate(valid_actual = fct_reorder(valid_actual, mean_acc,
                                    .fun = mean, .desc = TRUE))
New names:Rows: 100 Columns: 11── Column specification ───────────────────────────────────────────────────────────────────────────────────────────
Delimiter: ","
chr (7): species, collector, collection, country, dna_concentration, library_id, filename_root
dbl (4): ...1, sample_number, insert_size, content_sd
ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.Joining with `by = join_by(sample_valid)`
# Paper version: one facet per species, per-sample trajectories colored
# by DNA quality rank, with a dashed black mean line across samples.
# Fix: ggplot2 >= 3.4 deprecates `size` for line geoms; `linewidth` is
# the supported aesthetic for line width.
p = ggplot(df_facet_plot) +
  #stat_summary(aes(x = n_samp_training/10, y = mean_acc), fill = gray(.8), fun.max = function(x){quantile(x,0.75, type =4)},fun.min = function(x){quantile(x,0.25,type =4)}, geom='ribbon') +
  geom_line(aes(x = n_samp_training/10, y = mean_acc,
                group = sample_valid, color = dna_quality), alpha = 0.5) +
  # mean accuracy across samples within each species facet
  stat_summary(aes(x = n_samp_training/10, y = mean_acc), color = 'black',
               linewidth = 0.5, linetype = 'dashed', fun = 'mean', geom='line') +
  scale_color_viridis_c(name ='DNA quality rank') +
  #scale_linetype_manual(values = c('TRUE' = "solid", 'FALSE' = "51"), name = 'DNA quality', labels = c('TRUE' = 'High', 'FALSE' = 'Low')) +
  scale_x_continuous(breaks = 1:7) +
  ylab('Average validation accuracy') +
  xlab('Training samples per species') +
  facet_wrap(~valid_actual, nrow = 1) +
  theme_few(base_size = 6) +
  theme(legend.key.size = unit(0.2, "cm"))
p
ggsave(filename = 'n_samples.png', plot = p, device = 'png',
       path = 'paper_images', width = 16, height = 5, units = 'cm', dpi = 2400)
Finally, let’s plot the actual varKodes for Stigmaphyllon, with each species in a row, ordered by quality. We start by generating the appropriate dataframe:
# Locate the varKode image file for a given sample id.
#
# Searches `path` for files named "<anything>+<sample_id>_<anything>" and
# returns the alphabetically last match; file names encode the amount of
# data as a zero-padded number, so this is the image built from the most
# data. Returns NA_character_ when nothing matches, so a missing image
# yields an NA path instead of crashing the rowwise mutate() downstream
# (x[0] would be a zero-length vector there).
#
# sample_id: sample identifier, e.g. "S-91".
# path: directory holding the images; defaults to 'images_7' (the 7-mer
#       images), preserving the original behavior.
find_image = function(sample_id, path = 'images_7'){
  x = list.files(path = path,
                 pattern = paste0("^.+\\+", sample_id, "_.+"),
                 full.names = TRUE)
  if (length(x) == 0) {
    return(NA_character_)
  }
  return(x[length(x)])
}
# One row per validation sample at the largest training-set size
# (n_samp_training == 70, i.e. 7 samples per species), with the path to
# its image and a within-species rank by content_sd (rank 1 = largest
# content_sd). Species names get "_" rewritten as ". " for display.
df_varKode_plot <- df_facet_plot %>%
  filter(n_samp_training == 70) %>%
  select(sample_valid, valid_actual, content_sd, mean_acc) %>%
  distinct() %>%
  rowwise() %>%
  mutate(image_path = find_image(sample_valid)) %>%
  group_by(valid_actual) %>%
  arrange(desc(content_sd)) %>%
  mutate(quality_rank = row_number()) %>%
  ungroup() %>%
  arrange(valid_actual, quality_rank) %>%
  mutate(valid_actual = str_replace_all(valid_actual, "_", ". "))
df_varKode_plot
Now let’s plot
# Grid of varKode images: one row per species, columns ordered by DNA
# quality rank; the tile behind each image encodes that sample's mean
# validation accuracy.
p <- ggplot(df_varKode_plot, aes(x = quality_rank, y = valid_actual)) +
  geom_tile(aes(fill = mean_acc), color = "white") +
  scale_fill_viridis_c("Average validation accuracy", option = 'magma',
                       limits = c(0, 1), labels = scales::percent) +
  ggimage::geom_image(aes(image = image_path), size = 0.09) +
  coord_equal() +
  theme_minimal() +
  labs(title = expression(paste("varKodes for species of ", italic("Stigmaphyllon"))),
       x = "DNA quality rank",
       y = "Actual species") +
  # NOTE(review): quality_rank is numeric; a discrete x scale here mainly
  # suppresses the numeric axis breaks — confirm this is intended.
  scale_x_discrete() +
  theme(plot.background = element_rect(fill = "white", color = "white"),
        panel.background = element_rect(fill = "white", color = "white"),
        plot.title = element_text(hjust = 0.5),
        panel.grid = element_blank(),
        axis.text.y = element_text(face = 'italic'),
        legend.position = 'bottom')
print(p)
# Save the plot in raster and vector formats
ggsave(filename = 'varkodes_quality.png', plot = p, device = 'png',
       path = 'paper_images', width = 7, height = 7, units = 'in', dpi = 1200)
ggsave(filename = 'varkodes_quality.pdf', plot = p, device = 'pdf',
       path = 'paper_images', width = 7, height = 7, units = 'in')